
[MLAS] Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback #28416

Open
melkap01-Arm wants to merge 26 commits into microsoft:main from melkap01-Arm:fp8_DynamicQuantMatMul_Support

Contributor

@melkap01-Arm melkap01-Arm commented May 8, 2026

Description

Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback

This MR adds a CPU contrib implementation for com.microsoft::DynamicQuantMatMulFp8. The path supports dynamic block-wise quantization of float/float16/bfloat16 activations to FP8, FP8 runtime B, constant non-FP8 B pre-quantization, block-wise scales, configurable block sizes, and float/float16/bfloat16 outputs.

Main Changes

Adds DynamicQuantMatMulFp8 schema under the Microsoft contrib opset.
Registers the CPU contrib kernel when FP8 types are enabled.
Adds a scalar fallback implementation of MlasFp8GemmBatch in qgemm_fp8.cpp, which performs the FP8 GEMM compute path used by the CPU kernel.
Adds provider tests for the FP8 opkernel path.

Operator Contract

A supports float, float16, and bfloat16.
Runtime B supports FP8 only and must be rank-2.
Constant initializer B supports float, float16, bfloat16, or FP8.
Non-FP8 constant B is dynamically quantized once during PrePack.
Dynamic non-FP8 B is intentionally rejected.
Output Y supports float, float16, and bfloat16.
Optional Y_scale and Y_zero_point are supported.

FP8 formats supported:

FLOAT8E4M3FN
FLOAT8E4M3FNUZ
FLOAT8E5M2
FLOAT8E5M2FNUZ

Quantization Semantics

The implementation enforces symmetric quantization.
All provided zero-point inputs must encode 0.0.
Non-zero zero points are rejected.
Scale values are validated as finite and positive before use.
A scales are computed dynamically by the kernel.
For non-FP8 constant B, B scales are computed during PrePack.
For FP8 runtime/constant B, B_scale is required and validated.
Y_scale, when provided, must be scalar and is applied to the final accumulation.
Y_zero_point, when provided, must be scalar and zero-valued.
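
As an illustration of the dynamic, symmetric, block-wise A scales described above, here is a minimal sketch. The helper name `ComputeRowBlockScales`, the `max|x| / fp8_max` scale rule, and the choice of scale 1 for an all-zero block are assumptions made for this example, not taken from the PR. Only the FLOAT8E4M3FN maximum (448) is a property of the format itself.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Largest finite magnitude representable in FLOAT8E4M3FN.
constexpr float kFp8E4M3FnMax = 448.0f;

// Hypothetical helper: one symmetric scale per (row, K-block) of a row-major
// M x K matrix. Assumptions: scale = max|x| / fp8_max, an all-zero block keeps
// scale 1 so every scale stays finite and positive, and block_size_k != 0.
std::vector<float> ComputeRowBlockScales(const std::vector<float>& a,
                                         std::size_t m, std::size_t k,
                                         std::size_t block_size_k) {
  const std::size_t blocks_k = (k + block_size_k - 1) / block_size_k;
  std::vector<float> scales(m * blocks_k, 1.0f);
  for (std::size_t row = 0; row < m; ++row) {
    for (std::size_t kb = 0; kb < blocks_k; ++kb) {
      const std::size_t begin = kb * block_size_k;
      const std::size_t end = std::min(begin + block_size_k, k);
      float max_abs = 0.0f;
      for (std::size_t i = begin; i < end; ++i) {
        max_abs = std::max(max_abs, std::fabs(a[row * k + i]));
      }
      if (max_abs > 0.0f) {
        scales[row * blocks_k + kb] = max_abs / kFp8E4M3FnMax;
      }
    }
  }
  return scales;
}
```

Because the quantization is symmetric, no zero point is produced: dividing each element by its block scale maps the block into the representable FP8 range around zero.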

Block Layout

Adds block_size_m, block_size_k, and block_size_n.
block_size_m defaults to 1 and is currently constrained to 1.
block_size_k and block_size_n default to 128.
A scale layout is row/block-K based and generated internally by the kernel.
B_scale and B_zero_point use [N / block_size_n, K / block_size_k] layout.
Shape inference was tightened to match runtime behavior, including rank-2 B enforcement.
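
To make the [N / block_size_n, K / block_size_k] layout concrete, the following sketch computes which B_scale entry applies to a given element of B. `BScaleIndex` is a hypothetical helper, and the row-major flattening and ceiling division for the K block count are assumptions; the PR may instead require K and N to be exact multiples of the block sizes.

```cpp
#include <cstddef>

// Hypothetical helper: flat index into a row-major [blocks_n, blocks_k] scale
// tensor for the element of B at column n and depth k.
// Assumption: blocks_k = ceil(K / block_size_k), and both block sizes are non-zero.
std::size_t BScaleIndex(std::size_t n, std::size_t k, std::size_t K,
                        std::size_t block_size_n, std::size_t block_size_k) {
  const std::size_t blocks_k = (K + block_size_k - 1) / block_size_k;
  return (n / block_size_n) * blocks_k + (k / block_size_k);
}
```

With the default 128 x 128 blocks and K = 256, every element in the first 128 columns and first 128 depth positions shares scale index 0, and the next K block in the same column block uses index 1.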

Kernel Behavior

Runtime FP8 B is consumed directly.
Constant non-FP8 B is quantized to FP8 in PrePack.
Constant FP8 B keeps its FP8 type metadata.
Prepacked metadata restores B shape, FP8 type, quantized B size, and B scale count for shared prepack reuse.
B/B-zero-point FP8 type consistency is validated regardless of whether B type came from runtime B or prepack metadata.
K == 0 produces zero-filled output instead of returning uninitialized data.
M == 0 and N == 0 empty outputs return cleanly after cheap runtime contract validation.

MLAS FP8 Fallback

Adds MlasFp8GemmBatch.
Implements FP8 decode, block-wise scale application, float accumulation, and optional output scaling.
Supports all four FP8 modes listed above.
Parallelizes fallback work over BatchN * M.
Adds defensive validation before threaded execution:

valid FP8 mode
non-zero block sizes
required pointers only when actually dereferenced
leading dimensions only when used
strided offset overflow checks
block scale offset overflow checks
caller-provided block counts must match the GEMM shape and block sizes

This is a functional scalar fallback, not a hardware-optimized FP8 GEMM backend.
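
The decode, block-wise scale application, and float accumulation steps listed above can be sketched for a single output element as follows. This is an illustrative sketch only: `DecodeE4M3Fn` and `BlockScaledDot` are hypothetical names, only the FLOAT8E4M3FN format is handled, and the per-block ordering of scale application is an assumption rather than the PR's actual loop structure.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>

// Decode one FLOAT8E4M3FN byte: 1 sign bit, 4 exponent bits (bias 7),
// 3 mantissa bits, no infinities, NaN encoded as S.1111.111.
float DecodeE4M3Fn(std::uint8_t byte) {
  const float sign = (byte & 0x80) ? -1.0f : 1.0f;
  const int exponent = (byte >> 3) & 0x0F;
  const int mantissa = byte & 0x07;
  if (exponent == 0x0F && mantissa == 0x07) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  if (exponent == 0) {
    // Subnormal: (mantissa / 8) * 2^-6 == mantissa * 2^-9.
    return sign * std::ldexp(static_cast<float>(mantissa), -9);
  }
  // Normal: (1 + mantissa / 8) * 2^(exponent - 7) == (8 + mantissa) * 2^(exponent - 10).
  return sign * std::ldexp(static_cast<float>(8 + mantissa), exponent - 10);
}

// Hypothetical scalar dot product over K with one A scale and one B scale per
// K-block. Assumption: scales are applied once per block to the partial sum.
float BlockScaledDot(const std::uint8_t* a, const std::uint8_t* b,
                     const float* a_scales, const float* b_scales,
                     std::size_t k, std::size_t block_size_k) {
  float acc = 0.0f;
  for (std::size_t begin = 0; begin < k; begin += block_size_k) {
    const std::size_t end = std::min(begin + block_size_k, k);
    float partial = 0.0f;
    for (std::size_t i = begin; i < end; ++i) {
      partial += DecodeE4M3Fn(a[i]) * DecodeE4M3Fn(b[i]);
    }
    const std::size_t kb = begin / block_size_k;
    acc += partial * a_scales[kb] * b_scales[kb];
  }
  return acc;
}
```

A full GEMM would repeat this for every (batch, row, column) triple, which is the triple-loop cost the description warns about.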

Tests

Provider tests cover:

Constant non-FP8 B prepack path
Runtime FP8 B path
All four FP8 formats
Omitted optional output quantization inputs
Optional Y_scale
Float16 and bfloat16 outputs
Bfloat16 scale tensors
Symmetric zero-point rejection for B/Y
FP8 B / B-zero-point type mismatch rejection
Non-default block sizes
Shared prepacked B metadata restore
Shared prepack semantic correctness with different B scales
Rejection of unsupported dynamic non-FP8 B
Runtime B rank > 2 schema/runtime rejection
Malformed B scale shape validation before scale reads
M == 0, N == 0, and K == 0 edge cases
Invalid Y_scale shape, value, and type on the K == 0 path

Known Limitations

Dynamic non-FP8 B is not supported by design.
No packed-B optimized FP8 backend is exposed in this MR.
No KleidiAI FP8 dispatch is included in this path.
MLAS FP8 GEMM is currently correctness-oriented scalar fallback code, not a production performance kernel.
Full MatMul broadcast semantics for batched B are intentionally not supported; runtime/schema validation is restricted to rank-2 B.

Verification

Built onnxruntime_provider_test.
Ran FP8 provider tests successfully.
Ran the converted Qwen3 ONNX model successfully.
All DynamicQuantMatMulFp8 tests passed.

Motivation and Context

Adds DynamicQuantMatMulFp8 schema under the Microsoft contrib opset.
Registers the CPU contrib kernel when FP8 types are enabled.
Adds dynamic_quant_matmul_fp8.{h,cc} CPU kernel implementation.
Adds MLAS FP8 GEMM API surface and scalar fallback implementation in qgemm_fp8.cpp.
Wires the MLAS FP8 source into the MLAS build.
Adds provider tests for the FP8 op-kernel path.

Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
…tion enforced

Signed-off-by: melkap01 <melike.kaptan@arm.com>
@melkap01-Arm melkap01-Arm marked this pull request as ready for review May 11, 2026 14:58
@hariharans29 hariharans29 requested a review from Copilot May 11, 2026 17:09
Contributor

Copilot AI left a comment


Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds a CPU contrib implementation for com.microsoft::DynamicQuantMatMulFp8 backed by a new MLAS FP8 GEMM API, including a scalar fallback implementation and provider/MLAS unit tests.

Changes:

  • Introduces DynamicQuantMatMulFp8 schema, CPU kernel registration, and a CPU opkernel implementation with prepack support for constant non-FP8 B.
  • Adds MLAS FP8 GEMM public API (MlasFp8GemmBatch) and scalar fallback implementation, plus a shared size_t overflow helper.
  • Adds provider tests for the new contrib op and MLAS unit tests for the FP8 GEMM path.

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 9 comments.

| File | Description |
| --- | --- |
| onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp | Adds MLAS unit tests for the FP8 GEMM batch API (threaded + edge cases). |
| onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc | Adds CPU provider tests covering op contract, prepack, scale/zero-point validation, and edge cases. |
| onnxruntime/core/mlas/lib/qgemm_fp8.cpp | Implements scalar fallback for MlasFp8GemmBatch with validation and parallelism. |
| onnxruntime/core/mlas/lib/mlasi.h | Adds MlasMultiplyOverflowsSizeT helper used for overflow-safe size computations. |
| onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | Switches overflow checks to the shared MlasMultiplyOverflowsSizeT helper. |
| onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp | Switches overflow checks to the shared MlasMultiplyOverflowsSizeT helper. |
| onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h | Renames a parameter and removes the local overflow helper in favor of shared MLAS helper. |
| onnxruntime/core/mlas/inc/mlas.h | Adds public MLAS FP8 GEMM structs/API and the FP8 mode enum. |
| onnxruntime/core/graph/contrib_ops/quantization_defs.cc | Adds the DynamicQuantMatMulFp8 contrib operator schema + shape inference. |
| onnxruntime/core/graph/contrib_ops/ms_opset.h | Registers the new contrib schema in the Microsoft opset (gated on float8 support). |
| onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h | Declares the new CPU contrib kernel with prepack/shared-prepack support. |
| onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc | Implements the CPU kernel, including PrePack quantization of constant B and MLAS dispatch. |
| onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc | Registers the CPU kernel when float8 types are enabled. |
| cmake/onnxruntime_mlas.cmake | Wires qgemm_fp8.cpp into the MLAS build. |


Comment thread onnxruntime/core/mlas/lib/qgemm_fp8.cpp Outdated
Comment thread onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc Outdated
Comment thread onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc Outdated
Comment thread onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc Outdated
Comment thread onnxruntime/core/graph/contrib_ops/quantization_defs.cc Outdated
Comment thread onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc Outdated
Comment thread onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc Outdated
Comment thread onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc
Comment thread onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc
Signed-off-by: melkap01 <melike.kaptan@arm.com>
…es implemented

Signed-off-by: melkap01 <melike.kaptan@arm.com>
Signed-off-by: melkap01 <melike.kaptan@arm.com>
@melkap01-Arm melkap01-Arm marked this pull request as draft May 13, 2026 15:50
Signed-off-by: melkap01 <melike.kaptan@arm.com>
@melkap01-Arm melkap01-Arm marked this pull request as ready for review May 13, 2026 20:59
@hariharans29 hariharans29 changed the title from "Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback" to "[MLAS] Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback" May 14, 2026
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 1 comment.

Comment thread onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc Outdated
@melkap01-Arm melkap01-Arm marked this pull request as draft May 20, 2026 15:00
Signed-off-by: melkap01 <melike.kaptan@arm.com>
@melkap01-Arm melkap01-Arm marked this pull request as ready for review May 20, 2026 22:22
@hariharans29 hariharans29 requested a review from Copilot May 20, 2026 22:24
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 3 comments.

Comment thread onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc Outdated
Comment thread onnxruntime/core/graph/contrib_ops/quantization_defs.cc
Comment thread onnxruntime/core/mlas/inc/mlas.h Outdated
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 2 comments.

Comment thread onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc Outdated
Comment thread onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@hariharans29
Member

Review — PR #28416: [MLAS] Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback

Big surface area: new contrib op + schema + shape inference + prepack + a fresh public MLAS API (MlasFp8GemmBatch) + scalar fallback + tests + an unrelated MLAS helper refactor (MlasMultiplyOverflowsSizeT). ~2.4k lines. Functionally it looks correct and the test coverage is genuinely good, but I have substantive concerns about the framing of this as an "MLAS FP8" path and about a few API/perf decisions that will be painful to walk back once shipped.

Verdict: needs work before merge. Not because of correctness, but because (a) the MLAS public API being introduced is shipped without any optimized backend behind it, (b) the data-params struct will be ABI-fragile, and (c) several validations get redundantly executed on every Compute().


Substantive concerns

1. MlasFp8GemmBatch is a scalar reference implementation behind a permanent public ABI.
qgemm_fp8.cpp is O(M·N·K) with FP8 decode (Fp8ByteToFloat switch) on the inner element, no register blocking, no SIMD, no packed B. The PR description acknowledges this ("correctness-oriented scalar fallback, not a production performance kernel") — but it lands in mlas.h as MlasFp8GemmBatch, which downstream code will reasonably assume is the MLAS performance path. There is no KleidiAI dispatch added, no x86 dispatch, no dispatch struct at all; today the API has exactly one implementation and that implementation is a triple loop. A user converting Qwen3 to this op on CPU will get unacceptable throughput and have no diagnostic to indicate "this is the reference path." Two options:

  • Rename to MlasReferenceFp8GemmBatch (or keep it internal to the contrib op for now) until at least one optimized backend lands.
  • Or land a KleidiAI FP8 dispatch in the same MR so the public API has a real reason to exist.

The current framing risks normalizing a slow path under an MLAS name.

2. MLAS_FP8_GEMM_DATA_PARAMS will be ABI-fragile.
14 fields, Fp8Type is per-batch-element (genuinely unusual — why would batch entries vary mode?), BlocksM/K/N are caller-provided and re-derived/validated inside CheckFp8GemmBatchParams, and the strides (ScaleAStrideM, ScaleAStrideK, ScaleBStrideK, ScaleBStrideN) are also caller-provided. Future variants (per-channel B, per-tensor A, group quant) will require new fields, which is breaking for statically linked consumers. Suggest either (a) hide most of these behind a MLAS_FP8_GEMM_BLOCK_LAYOUT sub-struct, or (b) commit to a versioned config pattern now.

Also: Fp8Type per DataParams but Shape is shared — if a caller varies Fp8Type per batch, the kernel implementation cannot precompute mode-specific decode tables. Make it part of MLAS_FP8_GEMM_SHAPE_PARAMS.

3. mlas_fp8_mode naming and provenance.
Lowercase snake_case enum in a public MLAS header breaks convention (rest of MLAS is MLAS_FOO). The comment "FP8 mode is aligned with Arm KleidiAI format/overflow modes. Defined here to keep MLAS FP8 APIs platform-agnostic" then leaks the implementation alignment into the public name. Rename to MLAS_FP8_MODE and drop the KleidiAI reference.

4. Per-Compute() redundant validation.
ValidateZeroPointValuesAreZero walks the full B zero-point tensor every Run() call, even when B (and therefore the zero-point) is a constant initializer. Same for ValidatePositiveFiniteScales(b_scales, …) when scales came from a constant initializer. In a serving scenario these checks burn CPU on every inference. Either:

  • Validate constant inputs once at OpKernel construction / PrePack and cache, or
  • Drop the zero-point inputs entirely from the schema since the op is symmetric-only and the kernel rejects non-zero values anyway.

The latter is cleaner: an op whose contract is "all zero points must encode 0.0" probably shouldn't accept zero-point inputs at all. Make symmetric explicit in the schema and remove the input slots.
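
For the first option (validate once and cache), a minimal sketch of the idea follows. `ScaleValidationCache` and its methods are hypothetical and are not part of ONNX Runtime or this PR; the sketch only shows how a one-time check on a constant initializer can replace the per-call scan, while runtime inputs are still checked every call.

```cpp
#include <cmath>
#include <cstddef>

// True when every scale is finite and strictly positive.
bool AllFinitePositive(const float* scales, std::size_t count) {
  for (std::size_t i = 0; i < count; ++i) {
    if (!(std::isfinite(scales[i]) && scales[i] > 0.0f)) {
      return false;
    }
  }
  return true;
}

// Hypothetical cache: constant scales are scanned once, runtime scales every call.
class ScaleValidationCache {
 public:
  // Intended to be called once for a constant initializer (for example from a
  // prepack-style hook).
  void ValidateConstant(const float* scales, std::size_t count) {
    constant_ok_ = AllFinitePositive(scales, count);
    constant_validated_ = true;
    ++scans_;
  }

  // Intended to be called on every compute. Reuses the cached verdict when the
  // scales were a constant initializer, otherwise scans the runtime tensor.
  bool ValidateForCompute(const float* scales, std::size_t count) {
    if (constant_validated_) {
      return constant_ok_;
    }
    ++scans_;
    return AllFinitePositive(scales, count);
  }

  std::size_t scans() const { return scans_; }

 private:
  bool constant_validated_ = false;
  bool constant_ok_ = false;
  std::size_t scans_ = 0;
};
```

The same pattern would apply to the zero-point scan if the zero-point inputs stay in the schema.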

5. block_size_m attribute that must be 1.
The schema exposes block_size_m, the kernel rejects everything except 1, and shape inference enforces 1. This is a future-compatibility footgun: once a model is in the wild with block_size_m=1 as a baked attribute, you cannot change the default semantics. Either drop the attribute now (the op is row-wise-A by contract) or actually implement block_size_m > 1.

6. Per-Compute() allocations from the temp allocator.
Every Run() allocates: A FP8 buffer (M*K*batch), A scale buffer, optional Y float scratch (when output is fp16/bf16), optional B-scale float-conversion buffer. For a 70B-class model with thousands of MatMul calls per token, this is meaningful allocator churn. At minimum, the A FP8 + A scale pair should be sized once and reused, or you should rely on the arena allocator — but note the bug report I just looked at (#28654) shows the arena being disabled in practice by users, so don't over-rely on it.

7. kPackedBMetadataHasFp8ModeIndex is dead weight.
PrePack always writes 1 to this slot; RestorePackedBMetadata rejects anything else. Drop the field and bump kPackedBMetadataVersion if you want to preserve forward compatibility.

8. b_scales_ ownership gymnastics.
The IAllocatorUniquePtr<float> ↔ IAllocatorUniquePtr<void> round-trip via lambda-wrapped deleters in both PrePack and UseSharedPrePackedBuffers is fragile (get_deleter() is moved-from before release() and the assumption that the deleter stays valid is non-obvious). Either store B scales as IAllocatorUniquePtr<void> + element count throughout, or add a small static_cast helper that does this conversion in one place with a comment about why it's safe.

9. Test Fp8Gemm, ZeroColumnReturnsBeforeWorkItemOverflow doesn't actually test what it claims.
BatchN=2, M=SIZE_MAX, N=0 — the BatchN==0 || M==0 || N==0 early return triggers on N==0 before CheckedWorkItems runs, so the overflow path is never exercised. If you want to exercise the overflow guard, you need N != 0, but then ORT_ENFORCE aborts the process. Either:

  • Convert CheckedWorkItems and CheckFp8GemmBatchParams to return Status (preferred — ORT_ENFORCE in MLAS code that's reachable from user-controlled shapes is a DoS vector), or
  • Delete the misleading test and rely on the contrib-op-level guards.
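
A sketch of the preferred option, where an overflowing shape yields a recoverable error instead of aborting. Here `std::optional` stands in for a Status return, and `CheckedMultiply` and `CheckedWorkItemCount` are hypothetical names rather than the PR's actual `CheckedWorkItems` signature; the BatchN * M work-item count is taken from the PR description.

```cpp
#include <cstddef>
#include <limits>
#include <optional>

// Returns a * b, or std::nullopt when the product does not fit in size_t.
std::optional<std::size_t> CheckedMultiply(std::size_t a, std::size_t b) {
  if (a != 0 && b > std::numeric_limits<std::size_t>::max() / a) {
    return std::nullopt;
  }
  return a * b;
}

// Hypothetical work-item count for a fallback parallelized over BatchN * M.
// The caller can turn std::nullopt into an error status instead of aborting.
std::optional<std::size_t> CheckedWorkItemCount(std::size_t batch_n, std::size_t m) {
  return CheckedMultiply(batch_n, m);
}
```

With a non-aborting return, a unit test can pass BatchN = 2 and M = SIZE_MAX with a non-zero N and assert on the failure directly, which is what the current test name implies.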

10. MlasMultiplyOverflowsSizeT refactor is unrelated.
The migration from kleidiai/mlasi_kleidiai.h to mlasi.h is a sensible cleanup but belongs in its own MR. Mixing it in here makes the diff harder to review and the revert story worse if something regresses. Split.


Nits

  • core/common/fp8_common.h carries only Arm copyright but is a generic core utility. Consider Microsoft co-attribution per the repo's usual practice for core/ files. (Doesn't block merge.)
  • Fp8ByteToFloat's default: return 0.0f; swallows invalid modes silently. Prefer ORT_THROW or __builtin_unreachable() paired with a debug assert — it's only called after IsValidFp8Mode checks anyway, but the silent zero is the wrong fallback.
  • QuantizeBlockwiseFp8ABlockDynamic and QuantizeBlockwiseFp8WithScales each iterate src[row_offset + k] twice (once for max_abs, once for quantize). For bf16/fp16 inputs that means two casts per element. Minor, but the comment says "match KleidiAI" — the layout matches, the perf does not.
  • gemm_data.ScaleAStrideK = 1; gemm_data.ScaleAStrideM = blocks_k; and the symmetric B strides are caller-derived and re-validated in MLAS — pick one source of truth.
  • Schema lists Inputs (2 - 6) but documents 6 inputs. The optional-input chaining (AddOptionalInputEdge between B_zero_point and Y_scale) is correct but easy for model authors to get wrong. A short example in the schema doc would help.
  • Y_zero_point provided without Y_scale is silently accepted. Either enforce both-or-neither at the schema level or document explicitly that Y_zero_point requires Y_scale.
  • Tests use 1e-5f absolute tolerance for some cases and 0.5f for others — the 0.5f is reasonable for FP8 quantization noise but is an enormous tolerance; consider asserting a relative error bound instead so silent precision regressions surface.
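
As a sketch of the relative-error suggestion in the last nit, a test helper along these lines could replace the fixed 0.5f bound. `WithinRelativeError` is a hypothetical helper, and the tolerance values in the example below are illustrative, not measured FP8 error bounds.

```cpp
#include <algorithm>
#include <cmath>

// True when |actual - expected| is within rel_tol * |expected|, with abs_floor
// as the minimum allowed error so expected values near zero do not demand
// exact equality.
bool WithinRelativeError(float actual, float expected, float rel_tol, float abs_floor) {
  const float bound = std::max(abs_floor, rel_tol * std::fabs(expected));
  return std::fabs(actual - expected) <= bound;
}
```

For an expected value of 0.5, an output of 0.9 passes a 0.5f absolute tolerance but fails a 1% relative bound, which is the kind of silent precision regression the nit is concerned about.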

Bottom line

Land the contrib op, but split the work:

  1. This MR: contrib op + schema + tests + an internal (un-exported) scalar FP8 GEMM helper used only by this kernel.
  2. Follow-up MR: promote MlasFp8GemmBatch to mlas.h together with at least one optimized backend (KleidiAI SME path is the obvious candidate). At that point lock down the DATA_PARAMS struct shape.
  3. Separate MR: the MlasMultiplyOverflowsSizeT refactor.

If splitting is not on the table, then before merge please at minimum: (a) rename the public symbol to make the reference nature explicit, (b) move per-Run() constant-input validation into PrePack, (c) drop the dead kPackedBMetadataHasFp8ModeIndex slot, (d) drop block_size_m attribute or implement it, (e) convert MLAS-side ORT_ENFORCE to Status, (f) rename mlas_fp8_mode → MLAS_FP8_MODE and remove the KleidiAI reference from the public header.
